{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Process the generated queries and code data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import sys\n",
    "import numpy as np\n",
    "import json \n",
    "import tqdm\n",
    "import re \n",
    "\n",
    "sys.path.append(\"../../\") # add package root to path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_queries(task_queries_file):\n",
    "    \"\"\"\n",
    "    Load task queries from txt file. The first line is the world context, and the rest are task queries line by line:\n",
    "    ''' \n",
    "    \n",
    "    objects = [table, cabinet, cabinet.drawer0, cabinet.drawer1, cabinet.drawer2, cabinet.drawer3, panda_robot] ; # open cabinet.drawer0\n",
    "    ...\n",
    "    '''\n",
    "    \"\"\"\n",
    "    with open(task_queries_file, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    \n",
    "    # use regex to extract the query in each line:\n",
    "    # objects = [table, cabinet, cabinet.drawer0, cabinet.drawer1, cabinet.drawer2, cabinet.drawer3, panda_robot] ; # open cabinet.drawer0\n",
    "\n",
    "    valid_line_pattern = re.compile(r'(?P<context>objects.*);\\s*#(?P<query>.*)')\n",
    "    task_queries = []\n",
    "    for line in lines:\n",
    "        match = valid_line_pattern.match(line)\n",
    "        if match:\n",
    "            context = match.group('context')\n",
    "            query = match.group('query')\n",
    "            task_query = context + \"; #\" + query\n",
    "            task_queries.append(task_query)\n",
    "\n",
    "    return task_queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "corrupted_task_query_files: ['table_cabinet_5.txt', 'table_cabinet_7.txt', 'table_cabinet_8.txt', 'table_cabinet_12.txt', 'table_cabinet_18.txt', 'table_cabinet_19.txt', 'table_cabinet_20.txt', 'table_cabinet_22.txt', 'table_cabinet_23.txt', 'table_cabinet_24.txt', 'table_cabinet_25.txt', 'table_cabinet_27.txt', 'table_cabinet_32.txt', 'table_cabinet_35.txt', 'table_cabinet_36.txt', 'table_cabinet_37.txt', 'table_cabinet_40.txt', 'table_cabinet_43.txt', 'table_cabinet_53.txt', 'table_cabinet_59.txt', 'table_cabinet_61.txt', 'table_cabinet_64.txt', 'table_cabinet_68.txt', 'table_cabinet_71.txt', 'table_cabinet_73.txt', 'table_cabinet_75.txt', 'table_cabinet_77.txt', 'table_cabinet_78.txt', 'table_cabinet_79.txt', 'table_cabinet_80.txt', 'table_cabinet_81.txt', 'table_cabinet_82.txt', 'table_cabinet_85.txt', 'table_cabinet_88.txt', 'table_cabinet_89.txt', 'table_cabinet_90.txt', 'table_cabinet_92.txt', 'table_cabinet_96.txt', 'table_cabinet_98.txt']\n",
      "Number of corrupted task query files: 39\n"
     ]
    }
   ],
   "source": [
    "task_query_dir = \"../../data/task_queries/\"\n",
    "\n",
    "# get number of valid queries in each task query file\n",
    "num_task_queries = {}\n",
    "for i in range(0, 99):\n",
    "    task_query_file = f\"table_cabinet_{i}.txt\"\n",
    "    task_query_path = os.path.join(task_query_dir, task_query_file)\n",
    "    if not os.path.exists(task_query_path):\n",
    "        continue\n",
    "    task_queries = load_queries(task_query_path)\n",
    "    num_task_queries[task_query_file] = len(task_queries)\n",
    "\n",
    "# get all task queries files without queries from num_task_queries\n",
    "corrupted_task_query_files = [k for k, v in num_task_queries.items() if v == 0]\n",
    "print(f\"corrupted_task_query_files: {corrupted_task_query_files}\")\n",
    "print(f\"Number of corrupted task query files: {len(corrupted_task_query_files)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_task_queries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code generation stopped at table_cabinet_48"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_raw_output(raw_path, processed_path):\n",
    "    \"\"\"\n",
    "    Convert raw output json to {query: code} pairs\n",
    "\n",
    "    Raw output json:\n",
    "    [{\n",
    "        \"context\": context,\n",
    "        \"query\": use_query,\n",
    "        \"src_fs\": src_fs,\n",
    "        \"code_str\": code_str,\n",
    "        \"gvars\": list(gvars.keys()),\n",
    "        \"lvars\": list(lvars.keys()),\n",
    "    },\n",
    "    ...\n",
    "    ]\n",
    "    \"\"\"\n",
    "    with open(raw_path, 'r') as f:\n",
    "        raw_data = json.load(f)\n",
    "    \n",
    "    processed_data = []\n",
    "    for data in raw_data:\n",
    "        context = data['context']\n",
    "        query = data['query']\n",
    "        query = context + query\n",
    "\n",
    "        src_fs = data['src_fs']\n",
    "        code = data['code_str']\n",
    "        if len(src_fs) > 0:\n",
    "            fs_definition_str = '\\n'.join([v for k, v in src_fs.items()])\n",
    "            code = fs_definition_str + '\\n' + code\n",
    "        \n",
    "        processed_data.append({\n",
    "            \"query\": query,\n",
    "            \"code\": code\n",
    "        })\n",
    "\n",
    "    with open(processed_path, 'w') as f:\n",
    "        json.dump(processed_data, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finish processing ../../data/generated_code/raw_table_cabinet_0.json to ../../data/generated_code/processed_table_cabinet_0.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_1.json to ../../data/generated_code/processed_table_cabinet_1.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_2.json to ../../data/generated_code/processed_table_cabinet_2.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_3.json to ../../data/generated_code/processed_table_cabinet_3.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_4.json to ../../data/generated_code/processed_table_cabinet_4.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_6.json to ../../data/generated_code/processed_table_cabinet_6.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_9.json to ../../data/generated_code/processed_table_cabinet_9.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_10.json to ../../data/generated_code/processed_table_cabinet_10.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_11.json to ../../data/generated_code/processed_table_cabinet_11.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_13.json to ../../data/generated_code/processed_table_cabinet_13.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_14.json to ../../data/generated_code/processed_table_cabinet_14.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_15.json to ../../data/generated_code/processed_table_cabinet_15.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_16.json to ../../data/generated_code/processed_table_cabinet_16.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_17.json to ../../data/generated_code/processed_table_cabinet_17.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_21.json to ../../data/generated_code/processed_table_cabinet_21.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_26.json to ../../data/generated_code/processed_table_cabinet_26.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_28.json to ../../data/generated_code/processed_table_cabinet_28.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_29.json to ../../data/generated_code/processed_table_cabinet_29.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_30.json to ../../data/generated_code/processed_table_cabinet_30.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_31.json to ../../data/generated_code/processed_table_cabinet_31.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_33.json to ../../data/generated_code/processed_table_cabinet_33.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_34.json to ../../data/generated_code/processed_table_cabinet_34.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_38.json to ../../data/generated_code/processed_table_cabinet_38.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_39.json to ../../data/generated_code/processed_table_cabinet_39.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_41.json to ../../data/generated_code/processed_table_cabinet_41.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_42.json to ../../data/generated_code/processed_table_cabinet_42.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_44.json to ../../data/generated_code/processed_table_cabinet_44.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_45.json to ../../data/generated_code/processed_table_cabinet_45.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_46.json to ../../data/generated_code/processed_table_cabinet_46.json\n",
      "Finish processing ../../data/generated_code/raw_table_cabinet_47.json to ../../data/generated_code/processed_table_cabinet_47.json\n"
     ]
    }
   ],
   "source": [
    "gen_code_dir = \"../../data/generated_code\"\n",
    "\n",
    "# convert all raw_table_cabinet_i.json to processed_table_cabinet_i.json\n",
    "idx_start = 0\n",
    "idx_end = 47\n",
    "raw_files_idx_not_exist = []\n",
    "raw_files_idx_process_failed = []\n",
    "for i in range(idx_start, idx_end+1):\n",
    "    raw_path = os.path.join(gen_code_dir, f\"raw_table_cabinet_{i}.json\")\n",
    "    processed_path = os.path.join(gen_code_dir, f\"processed_table_cabinet_{i}.json\")\n",
    "    if not os.path.exists(raw_path):\n",
    "        raw_files_idx_not_exist.append(i)\n",
    "        continue\n",
    "    try:\n",
    "        process_raw_output(raw_path, processed_path)\n",
    "        print(f\"Finish processing {raw_path} to {processed_path}\")\n",
    "    except:\n",
    "        raw_files_idx_process_failed.append(i)\n",
    "        continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5, 7, 8, 12, 18, 19, 20, 22, 23, 24, 25, 27, 32, 35, 36, 37, 40, 43]\n",
      "[]\n"
     ]
    }
   ],
   "source": [
    "print(raw_files_idx_not_exist)\n",
    "print(raw_files_idx_process_failed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of processed data: 4303\n"
     ]
    }
   ],
   "source": [
    "# calculate the number of processed data across all files\n",
    "idx_start = 0\n",
    "idx_end = 99\n",
    "num_processed_data_list = []\n",
    "for i in range(idx_start, idx_end+1):\n",
    "    processed_path = os.path.join(gen_code_dir, f\"processed_table_cabinet_{i}.json\")\n",
    "    if not os.path.exists(processed_path):\n",
    "        continue\n",
    "    with open(processed_path, 'r') as f:\n",
    "        processed_data = json.load(f)\n",
    "    num_processed_data_list.append(len(processed_data))\n",
    "\n",
    "num_processed_data = np.sum(num_processed_data_list)\n",
    "print(f\"Number of processed data: {num_processed_data}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "franka",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
